Conversation
…g and refined test case generation for various configurations. - Cleaned up unused variables and improved code readability in the FSDPAGTensor class by removing unnecessary parameters.
… FusedAdam. Added debug print for DTensor in MultiTensorApply.
… tolerances for tensor comparisons. Updated test logic to accommodate new tolerance parameters for improved accuracy in floating-point comparisons.
…l differences in gradient calculations. Cleaned up unused debug print statements in MultiTensorApply and ensured a proper newline at the end of the FSDPAGTensor serialization method.
if not isinstance(quantizer, MXFP8Quantizer) and not self._keep_fp8_weight_transpose_cache:
quantizer = module.quantizers["scaling_fwd"][self._fp8_meta_index]
if not isinstance(quantizer, MXFP8Quantizer):
    quantizer.set_usage(columnwise=False)
For FSDP2 with FP8, keep_fp8_weight_transpose_cache should be False. Caching the transposed weight would imply an all-gather of the transposed tensor as well, increasing memory and communication and negating the advantages of FSDP2’s sharded parameter layout.
data = torch.zeros_like(param, dtype=torch.int16)
else:
    data = torch.empty(param.shape, dtype=dtype, device=param.device)
    data = torch.empty_like(param, dtype=dtype)
When using FSDP2, parameters are DTensors, but torch.zeros() or torch.empty() create regular PyTorch Tensors.
This was causing:
[rank1]: RuntimeError: aten.copy_.default: got mixed torch.Tensor and DTensor, need to convert all torch.Tensor to DTensor before calling distributed operators!
[rank7]: File "/workspace/TransformerEngine/transformer_engine/pytorch/optimizers/fused_adam.py", line 422, in initialize_state
[rank7]: self.set_scaled_state(param, "master_param", param.clone().detach().float())
[rank7]: File "/workspace/TransformerEngine/transformer_engine/pytorch/optimizers/fused_adam.py", line 363, in set_scaled_state
[rank7]: state[state_name].copy_(unscaled_state)
Fix:
Keep optimizer state consistent with the parameter type: when parameters are DTensors, state should be DTensors as well. Using torch.empty_like(param, ...) (and the same idea for other state buffers) creates state as a DTensor with the same placement as param, so both sides of copy_ are DTensors and the error is avoided.
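A minimal sketch of the idea behind the fix. The helper name make_state_like is hypothetical (the PR's actual code lives in fused_adam.py); it only illustrates that the *_like factories inherit shape, device, and tensor subclass from the parameter, so under FSDP2 a DTensor parameter yields a DTensor state buffer and copy_ no longer mixes tensor types:

```python
import torch

def make_state_like(param, dtype=None):
    # Hypothetical helper: torch.empty_like propagates shape, device, and
    # the tensor subclass of `param` (a DTensor stays a DTensor with the
    # same placement), unlike torch.empty(param.shape, ...), which always
    # returns a plain torch.Tensor.
    return torch.empty_like(param, dtype=dtype if dtype is not None else param.dtype)

# Plain-tensor demonstration (a real FSDP2 run would pass a DTensor here).
p = torch.randn(4, 4)
state = make_state_like(p, dtype=torch.float32)
state.copy_(p)  # both sides are the same tensor type, so copy_ succeeds
```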
Is this a cherry-pick of the upstream fix?
Upstream fixes this in TEv2.12, along with a few other fixes.
NVIDIA/TransformerEngine@fe8fad5#diff-0801a8d92a56d458946da1439b62e0add1613b7da83d31bc218a852b6b9e42b1
This wasn't cherry-picked.
…by adding a newline character after the pass statement in the test_dummy function.
# Zero the parameter gradients
optimizer.zero_grad()
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
Does with te.fp8_autocast(enabled=args.fp8_autocast, ...) do the same?
It does the same, but since te.fp8_autocast was replaced with te.autocast in TEv2.10, I've made the change to be consistent.
So will 'with te.autocast(enabled=args.fp8_autocast, recipe=...)' do the same as the if/else?
Yes, it should. I'll make the changes.
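A sketch of the pattern being agreed on here. The autocast context manager below is a stand-in, not the Transformer Engine implementation; it only shows why a single block with an enabled flag can replace an explicit if/else over the FP8 and non-FP8 paths:

```python
from contextlib import contextmanager

@contextmanager
def autocast(enabled=True, recipe=None):
    # Stand-in for te.autocast: when enabled=False the context is a
    # no-op, so `with autocast(enabled=flag, recipe=...)` covers both
    # branches that an if/else would otherwise spell out.
    yield

def forward(x, use_fp8, recipe=None):
    with autocast(enabled=use_fp8, recipe=recipe):
        return x * 2  # placeholder for the model forward pass

# Same call site handles both configurations.
assert forward(3, use_fp8=True) == 6
assert forward(3, use_fp8=False) == 6
```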
assert len(l1) == len(l2), "Unequal number of outputs."
for i, (t1, t2) in enumerate(zip(l1, l2)):
    result = torch.allclose(t1, t2, atol=0, rtol=0)
    tols = dict(atol=atol)
Move the tols calculation out of the loop.
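An illustrative version of the suggested change. The function name assert_all_close and the tolerance values are placeholders, not the PR's actual test helper; the point is that the tolerance dict is identical for every pair, so it is built once before the loop:

```python
import torch

def assert_all_close(l1, l2, atol=1e-3, rtol=0.0):
    # Hoist the tolerance dict out of the loop: it does not depend on
    # the loop variables, so rebuilding it per element is wasted work.
    assert len(l1) == len(l2), "Unequal number of outputs."
    tols = dict(atol=atol, rtol=rtol)
    for i, (t1, t2) in enumerate(zip(l1, l2)):
        assert torch.allclose(t1, t2, **tols), f"Output {i} mismatch."

a = [torch.ones(3), torch.zeros(2)]
b = [torch.ones(3) + 5e-4, torch.zeros(2)]
assert_all_close(a, b)  # difference of 5e-4 is within atol=1e-3
```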
…s for improved clarity and consistency.
Manually ported the fix from upstream commit 139c863. The full commit was not cherry-picked due to unrelated changes across many files. Addressed PR comments.
Description
This PR adds unit tests covering different configurations such as:
All the unit tests compare FSDP2 vs DDP grads/output.
This PR also cleans up fsdp2_all_gather_tensor to match upstream's methods.
This PR also fixes an issue with fused_adam when using it with FSDP2.
Fixes # (https://github.com/ROCm/frameworks-internal/issues/15291)
Type of change
Checklist: